import torch.nn.functional as F
import torch
from layer import steps
def single_stage_at_loss(f_s, f_t, p):
    """Attention-transfer loss between one student and one teacher feature map.

    Each map is raised element-wise to the power ``p``, averaged over dim 1,
    flattened per sample to ``(batch, -1)`` and row-wise L2-normalised. The
    loss is the mean squared difference of the two resulting attention vectors.

    Args:
        f_s: student feature tensor; dim 0 is batch, dim 1 is reduced.
        f_t: teacher feature tensor with the same layout as ``f_s``.
        p: exponent applied before averaging.

    Returns:
        Scalar tensor.
    """
    def _attention_map(x, power):
        pooled = x.pow(power).mean(dim=1)
        flat = pooled.reshape(x.size(0), -1)
        return F.normalize(flat)

    diff = _attention_map(f_s, p) - _attention_map(f_t, p)
    return diff.pow(2).mean()

def at_loss(g_s, g_t, p):
    """Sum ``single_stage_at_loss`` over paired student/teacher stages.

    Args:
        g_s: iterable of student feature tensors.
        g_t: iterable of teacher feature tensors, paired with ``g_s`` by position
            (extra items in the longer iterable are ignored, as with ``zip``).
        p: exponent forwarded to ``single_stage_at_loss``.

    Returns:
        The summed loss; the integer ``0`` when there are no stage pairs.
    """
    total = 0
    for student_feat, teacher_feat in zip(g_s, g_t):
        total = total + single_stage_at_loss(student_feat, teacher_feat, p)
    return total


def poly_kernel(a, b):
    """Squared dot-product (degree-2 homogeneous polynomial) kernel.

    For inputs shaped ``(B, N, D)`` and ``(B, M, D)`` the result is shaped
    ``(B, M, N)`` with ``res[k, m, n] = (a[k, n] . b[k, m]) ** 2``. It appears
    intended for batched 3-D inputs; other ranks rely on broadcasting.
    NOTE(review): not called anywhere in the visible code.
    """
    lhs = b.unsqueeze(2)   # (..., M, 1, D)
    rhs = a.unsqueeze(1)   # (..., 1, N, D)
    dots = (lhs * rhs).sum(dim=-1)
    return dots.pow(2)

def similarity_loss(f_s, f_t, p):
    """Similarity-preserving distillation loss for one student/teacher feature pair.

    Each input is flattened to ``(batch, -1)``. The batch Gram matrix
    ``G = f @ f.T`` is computed and L2-normalised row-wise (``F.normalize``
    defaults to ``dim=1``). The student and teacher matrices are compared with a
    mean smooth-L1 loss (``beta=1.0``) after the element-wise map ``2 * exp(G)``.

    Args:
        f_s: student feature tensor; dim 0 is batch.
        f_t: teacher feature tensor with the same batch size as ``f_s``.
        p: unused here; presumably kept for signature parity with
            ``single_stage_at_loss`` so both can be driven the same way.

    Returns:
        Scalar tensor.
    """
    bsz = f_s.shape[0]
    # reshape (not view): view raises on non-contiguous inputs such as the
    # result of transpose/permute, while reshape copies only when it must.
    f_s = f_s.reshape(bsz, -1)
    f_t = f_t.reshape(bsz, -1)

    G_s = F.normalize(torch.mm(f_s, f_s.t()))
    G_t = F.normalize(torch.mm(f_t, f_t.t()))

    return F.smooth_l1_loss(
        2 * torch.exp(G_s), 2 * torch.exp(G_t), reduction="mean", beta=1.0
    )

def sp_loss(g_s, g_t, p):
    """Sum ``similarity_loss`` over paired student/teacher stages.

    Args:
        g_s: iterable of student feature tensors.
        g_t: iterable of teacher feature tensors, paired with ``g_s`` by position.
        p: forwarded unchanged to ``similarity_loss``.

    Returns:
        The summed loss; the integer ``0`` when there are no stage pairs.
    """
    per_stage = [
        similarity_loss(student_feat, teacher_feat, p)
        for student_feat, teacher_feat in zip(g_s, g_t)
    ]
    return sum(per_stage)

def Loss(student_feature, teacher_feature, target, type, alpha, beta, temperature,
         feature_weight=1000):
    """Compute the distillation training loss for one batch.

    Args:
        student_feature: 4-tuple ``(logits, s1, s2, s3)`` from the student.
        teacher_feature: 4-tuple ``(logits, t1, t2, t3)`` from the teacher.
        target: passed with the student logits to ``F.cross_entropy``.
        type: loss variant, one of
            ``'KD'``       -- ``alpha * CE + beta * T**2 * KL(teacher || student)``
                              on temperature-softened logits;
            ``'inter'``    -- ``CE + feature_weight * at_loss`` over the three
                              intermediate feature pairs;
            ``'relation'`` -- ``CE + feature_weight * sp_loss`` over the three
                              intermediate feature pairs.
        alpha: cross-entropy weight (``'KD'`` only).
        beta: KL-term weight (``'KD'`` only).
        temperature: softmax temperature ``T`` (``'KD'`` only).
        feature_weight: multiplier on the feature loss for ``'inter'`` and
            ``'relation'``; defaults to 1000, the previously hard-coded value.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if ``type`` is not one of the supported variants.
    """
    teacher_output, t1, t2, t3 = teacher_feature
    student_output, s1, s2, s3 = student_feature

    if type not in ('KD', 'inter', 'relation'):
        # Previously an unknown type fell through every branch and surfaced
        # as an opaque UnboundLocalError on `loss`.
        raise ValueError(
            f"unknown loss type {type!r}; expected 'KD', 'inter' or 'relation'"
        )

    ce_loss = F.cross_entropy(student_output, target)

    if type == 'KD':
        log_pred_student = F.log_softmax(student_output / temperature, dim=1)
        pred_teacher = F.softmax(teacher_output / temperature, dim=1)
        # kl_div(log q, p) summed over classes = KL(p || q), then batch mean.
        kd_loss = F.kl_div(log_pred_student, pred_teacher, reduction="none").sum(1).mean()
        # Conventional T**2 scaling from Hinton et al. (2015).
        kd_loss = kd_loss * temperature ** 2
        return alpha * ce_loss + beta * kd_loss

    # Student maps carry an extra axis at dim 4, which is summed and divided by
    # `steps` (imported from `layer`). Presumably this averages over simulation
    # time steps -- TODO confirm.
    student_maps = tuple(torch.sum(s, dim=4) / steps for s in (s1, s2, s3))
    teacher_maps = (t1, t2, t3)

    feature_loss_fn = at_loss if type == 'inter' else sp_loss
    # Pass (student, teacher) to match the g_s/g_t parameter order. Both losses
    # are symmetric in their two arguments, so the value is unchanged.
    return ce_loss + feature_weight * feature_loss_fn(student_maps, teacher_maps, 2)